FedGEN:面向异质联邦学习的无数据知识蒸馏框架
背景与引言:联邦学习(FL)是一种去中心化的机器学习范例,多方协作学习全局模型而无需访问各自的隐私数据。但是用户异构性给FL带来了巨大的挑战,这可能会导致FL全局模型漂移、收敛速度缓慢等问题。现有方法主要从两个互补的角度解决用户异质性:1)侧重于通过在参数空间上调节局部模型与全局模型的偏差(客户端层面)从而稳定局部训练;2)通过提高模型聚合的效率(服务器端层面)来缓解用户异质性。
最近有工作利用知识蒸馏(KD)来解决联邦学习中的用户异构性问题的想法,具体是通过使用来自异构用户的聚合知识优化全局模型,而不直接聚合用户的模型参数。然而这种方法却很依赖于代理数据集(proxy dataset),如果没有恰当的代理数据集,该方法便是不切实际的。此外,来自各客户端参与方的集成知识也没有被充分利用来指导本地模型的训练,这有可能反过来会影响全局聚合模型的性能。
基于上述挑战,本文提出了一种基于data-free方式的知识蒸馏框架来解决FL中的异构性问题,该方法被称为FedGEN:服务器学习一个轻量级生成器,基于data-free方式集成各参与方用户的信息(特征表示),然后再将这些集成信息(包含来自其它参与方客户端特征的提炼知识)广播给参与用户,本地用户使用学习到的知识作为归纳偏置从而调节局部客户端训练(归纳偏置就是基于先验知识对目标模型的判断,将无限可能的目标函数约束在一个有限的假设类别之中;额外说下自己对于归纳偏置的一些理解:模型对应假设空间中的一个假设,如果与训练集一致的假设有多个,那么应该选择哪一个模型呢?归纳偏好可以看作学习算法自身在一个庞大的假设空间中对假设进行选择的启发式或者价值观,即归纳出一定的规则,然后对模型做一定的约束)。
FedGEN基于用户集成信息在服务器上学习轻量级生成器,具有如下好处:1) 从参与用户中提取在模型平均之后原本要被减轻的知识,并且这些知识不依赖于任何外部数据;2) FedGEN使用提取的知识直接调节局部模型更新,来自用户的集成知识对局部模型施加了归纳偏差,因此具有更好的泛化性能;3) 对于更具挑战性的FL场景(例如特征漂移、动态FL等),FedGEN只需要局部模型的预测层来进行知识提取,因此更加方面易用,此外由于FedGEN所学习的生成器是轻量级的,只给当前FL系统带来较小开销。
Paper: http://proceedings.mlr.press/v139/zhu21b/zhu21b.pdf
Code: https://github.com/zhuangdizhu/FedGen
问题定义
在本文中,主要讨论用于监督学习的典型FL设置,即多类分类的一般问题:设定X∈R_p为样本空间,Z∈R_d为潜在的特征空间,T代表由X的数据分布D和真值标签函数(X—>y)组成的域,域T可以理解为用户端的任务T,模型参数为θ:= [θf ; θp]包含两个部分:一个是由θf参数化的特征提取器f:X—>Z,另一个是由θp参数化的预测函数h:Z—>△y。给定一个非负凸损失函数L:△y×y—>R,那么由θ参数化的模型在域T上的损失LT(θ)定义如下:
联邦学习FL旨在学习一个由θ参数化的全局模型,以最大限度地降低全局模型在每个用户任务Tk上的经验风险损失:
知识蒸馏KD是基于Teacher-Student模式,其目的是通过从一个或多个强大的Teacher模型提取的知识来引导学习一个轻量的Student模型,从而减轻能耗或提高性能。基于FL的知识蒸馏一种典型的方法是利用proxy dataset来最小化来自Techer模型 θ_T和Student模型 θ_s对应的logits输出(即最终的全连接层的输出,代表模型预测的置信度)。目前已有将知识蒸馏应用于FL的方法:该方法将用户模型θ_k作为Teacher,全局模型θ作为Student,然后通过聚合θ_k的信息来提高模型泛化性能,其优化目标如下所示:
然而,上述方法的一个主要限制在于其对proxy dataset的依赖,因此proxy dataset的选择在蒸馏过程中对模型性能起到关键作用。为应对这一限制,FedGEN提出以data-free的方式将知识蒸馏应用于FL。
FedGEN方法
FedGEN核心思想是提取关于数据分布的全局视图的知识(这些知识是传统FL无法观察到的来自全局的集成知识),并将这些知识提取到本地模型中以指导他们的学习。关于FedGEN学习过程的概述如下图1所示。
1)Knowledge Extraction:在这里,首先考虑学习一个条件分布Q:y—>x来描述来自参与方客户端的全局视图的集成知识:
其中p(y)和p(y|x) 分别是样本真值标签的先验分布和后验分布,两者都是未知的。但这一条件分布Q,在高维度的样本空间X的情况下将导致计算过载,并可能造成用户相关数据配置的泄露。因此,文章提出另一个可行的想法:直接恢复潜在空间上的诱导分布G:Y—>Z(潜在空间也就是全局数据的特征分布,比原始数据空间更加紧凑,并且可以减少某些隐私泄露的风险),具体如下:
基于上述推导,进一步地,FedGEN旨在通过学习一个条件生成器G来执行知识提取(恢复一个潜在空间上的特征),该生成器由w参数化,在服务端进行优化,优化目标如下:
其中g是logit输出,σ是激活函数。对于任意标签y,优化(4)只需要基于用户模型的预测模块θp_k。具体来说,为了使G(·| y)的输出多样化,FedGEN向生成器G_w引入了一个噪声向量ε ~ N(0, i),这样一来,给定任意的目标标签y,生成器G_w可以产生特征表示z~G_w(· | y),而该特征表示在潜在特征空间中可以根据所有用户模型的标签引导出理想的预测。
2)Knowledge Distillation:服务器将完成学习的生成器G_w广播给各个用户,使每个用户模型都可以从G_w中进行采样以获得特征空间上的增广表示:z~G_w(· | y),用户模型θ_k的优化目标也变为最大化其为增广样本产生理想预测的概率(个人理解就是本地模型训练原始数据 + 增广样本,使得损失最小):
通过生成器生成增强的样本可以向本地用户引入归纳偏差,以更好的泛化性能来增强本地模型训练学习。经过以上步骤,FedGEN通过交互学习出一个依赖于用户模型预测规则的轻量级生成器G_w,并利用该生成器向用户传递一致的知识,从而实现了data-free的知识提炼-蒸馏过程。
3)Extensions for Flexible Parameter Sharing:对于具有挑战性的FL场景:一方面,其具有深度特征提取层的高级网络通常包含数百万个参数,这给通信带来了巨大的负担;另一方面,基于FL的后门攻击方法已被证明是可行的,共享整个模型参数可能会带来相当大的隐私风险。FedGEN通过仅共享用户模型的预测层θp_k(即,生成器中标签到特征的函数),θf_k(特征提取器)则保留在用户本地,与共享整个模型的策略相比,这种部分共享模式更有效,同时更不容易发生数据泄漏。
关于FedGEN整体伪代码如下所示:
进一步理解FedGEN
知识蒸馏与归纳偏置:轻量级生成器G_w基于用户模型的预测规则进行学习,目的是融合来自用户模型的聚合信息来估计全局数据分布R(x|y),接着用户从G_w(x|y) 进行采样,采样结果作为自身的归纳偏置,进而调整决策边界,如下图3所示,在知识蒸馏KD之后,一个用户的准确率从81.2% 提高到了98.4%。
知识蒸馏与分布匹配:FedGEN和先前研究的主要区别在于:知识被蒸馏至用户模型,而不是全局模型。因此,蒸馏出来的知识(向用户传递的归纳偏置)可以通过在潜在空间Z上进行分布匹配,直接调节本地用户的学习。
知识蒸馏与泛化性能:用户异构性越高的话,则用户之间的数据分布差异也越高,这会降低全局模型的质量;而通过向本地用户提供与全局分布一致的增广数据,则可以提高泛化性能。
实验部分
实验设置:将最后一层作为预测器θpk,之前所有层作为特征提取器θfk。生成器Gw是基于MLP的。它以一个噪声向量ε和一个one-hot标签向量y作为输入,在经过一个维数为dh的隐藏层之后,输出一个维数为d的特征表示。对于MNIST和EMNIST数据集,建模non-iid数据分布,其中较小的α表示较高的数据异质性,因为它使得pk(y)分布对用户k更有偏差。对于CELEBA,原始数据自然是non-iid分布式的。然后通过将属于不同名人的图片聚合到不相交的组中,每个组分配给一个用户,进一步提高了数据的异质性。
在整体的实验部分,除非另有说明,否则都是进行了200轮全局迭代,总共有20个用户模型,活跃用户比率设置为r=50%。在本地局部训练过程,采用局部更新步骤T=20,每一步迭代使用大小为B=32的小批量数据。实验过程中,最多使用50%的总训练数据集进行模型训练,将其分发给用户模型,使用所有测试数据集进行性能评估。对于分类器基准网络,只将最后的MLP层视为预测器p_k,将所有之前的层视为特征提取器f_k,同时生成器G_w是基于MLP的:它以一个噪声向量ε和一个one-hot标签向量作为输入,基于一个维度为h的隐藏层之后,输出维度为d的特征表示。此外,实验中还基于Dirichlet分布来构造异构性non-iid数据分布,其中较小的参数α表示较高数据异质性。
性能概述:文章对FedGEN的多项性能进行了详细的实验分析:预测精度、学习效率、对游离用户的敏感度、对不同模型结构的敏感度、对通信延迟的敏感度、对生成器的网络结构和采样大小的敏感度、灵活参数共享的扩展性分析,以及数据集MNIST在迪利克雷分布下,基于参数α的不同取值下的异构性可视化。如下表1所示,我们可以观察到FedGEN性能显著优于其他基线方法。
数据异质性的影响:FedGEN是一种对不同级别的用户异质性具有鲁棒性同时始终表现良好的算法。如下图5所示,当数据分布高度异构时,FedGEN的增益更为显著,这也验证了本文的动机:FedGEN的优势是从本地用户中提取知识并集成归纳出来的,这缓解了用户之间潜在分布的差异性,然而FedAvg/FedProx等算法无法获取此类知识。
学习效率:如下图6所示,FedGEN具有最快的学习曲线来达到性能,并且优于其他基线方法。由于代理数据的优势,FedGEN方法可以直接使每个本地用户从主动学习到的知识受益,其效果更加明确和一致。
掉队用户的影响:基于CELEBA数据集探索总用户和活跃用户的不同数量,活跃比率r从0.2到0.9。从图5(b)可以看出,当掉队用户数较高时,FedGEN方法对于渐近性能仍然一致地优于其他所有基线方法。结合图6(a)和图6(b),可以观察到FedGEN方法需要更少的通信轮次来达到高性能,而不需要考虑掉队用户的设置带来的影响。
不同网络架构的影响:实验中使用CNN和MLP两种网络架构对MNIST数据集进行分析,如图5(c)和图5(d)所示,在两种不同的网络设置中,FedGEN方法,尽管使用CNN网络训练的整体性能明显高于使用MLP网络的整体性能,但是其出色性能是一致的。
生成器网络结构和采样大小的影响:如下表3所示的扩展分析,已经验证了FedGEN在不同的生成器网络架构上的健壮性。此外,从生成器中采样合成数据只会给本地用户增加少量的训练工作量。基于不同的合成样本大小,FedGEN的增益始终是显著优于其他方法的,并且足够数量的合成样本也会带来更好的性能(如下图7所示)。特别是在下表3中,在保持输出层维度固定的情况下,探索了生成器的输入噪声和隐藏层的不同维度之间的关系。FedGEN卓越性能验证了在non-iid这种具有挑战性的情况下无数据蒸馏的功效,FedGEN通过快速收敛和灵活的参数共享策略,也具有进一步减少通信工作量的潜力。
归纳总结
FedGEN通过聚合所有客户端模型的知识(标签信息)用来得到一个生成器模型,生成器可以根据标签Y生成特征Z,服务器将生成器广播给所有客户端,客户端通过生成器生成增广样本用来帮助本地模型训练(增广样本具有归纳偏置信息),通过生成器可以提炼出全局分布数据的知识给客户端,从而实现无信息的知识蒸馏。
代码解析
聚合所有客户端模型的知识训练生成器 代码位置:servers/serverpFedGen.py内train_generator函数 生成器模型定义:
其中passed_dataset就是代表数据集,因为不同的数据集对应不同的embedding以及layer因此需要加入这个参数,接下来我们看Generator这个生成器类:
就是叠加堆砌FC+BN+ReLU,最后生成特征,维度为latent_dim 然后我们看update_generator_函数:训练生成器
其中eps可以理解为生成器按照noise得出的结果,gen_output可以认为根据labels标签得到的结果,然后得出diversity_loss,根据diversity_loss训练生成器,其中diversity_loss定义如下:
主要就是计算l1,l2 loss以及cosine distance。
teacher_loss指的是分类loss,student_loss指的是生成器特征得到的标签与真实标签的KL Loss,最后就是diversity_loss训练生成器。通过以上三种loss最终得到生成器,然后服务器将生成器广播给所有客户端,客户端通过生成器生成增广样本用来帮助本地模型训练。
作者简介:
huan,西安电子科技大学硕士在读,计算机科学与技术专业。研究方向为联邦学习、模型压缩。往期推荐